from __future__ import print_function
import argparse
from re import L
from skimage import transform
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms.functional import _is_numpy_image
from torchvision.utils import save_image
# for jupyter style running
class Args:
    """Notebook-style stand-in for argparse command-line options.

    Exposes the same attributes the original argparse namespace would:
    batch_size, cuda, log_interval, epochs, seed.
    """

    # Defaults mirroring the original hard-coded hyperparameters.
    _DEFAULTS = {
        "batch_size": 128,   # samples per training batch
        "cuda": True,        # request GPU execution
        "log_interval": 10,  # batches between progress prints
        "epochs": 10,        # full passes over the training data
        "seed": 12345,       # RNG seed for reproducibility
    }

    def __init__(self):
        for name, value in self._DEFAULTS.items():
            setattr(self, name, value)
# Instantiate the option container and fix the RNG seed for reproducibility.
args = Args()
torch.manual_seed(args.seed)

# BUG FIX: the original built a "cuda" device from the hard-coded flag alone,
# which crashes at the first .to(device) call on machines without a GPU.
# Honor the flag only when CUDA is actually available.
use_cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# DataLoader extras (pinned memory, worker process) only pay off when
# feeding a GPU, so key them on the effective CUDA decision too.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# neural data
import nnfabrik
from nnfabrik import builder
import numpy as np
import pickle
import os
from os import listdir
from os.path import isfile, join
import matplotlib.pyplot as plt
import nnvision
# Locations of the CSRF19_V1 monkey V1 dataset on disk.
basepath = '/home/data/monkey/toliaslab/CSRF19_V1'
neuronal_data_path = os.path.join(basepath, 'neuronal_data/')

# Collect the full path of every regular file in the neuronal-data directory.
neuronal_data_files = []
for fname in listdir(neuronal_data_path):
    if isfile(join(neuronal_data_path, fname)):
        neuronal_data_files.append(neuronal_data_path + fname)

image_file = os.path.join(basepath, 'images/CSRF19_V1_images.pickle')
image_cache_path = os.path.join(basepath, 'images/individual')

# nnfabrik resolves the dataset-builder function from its dotted import path.
dataset_fn = 'nnvision.datasets.monkey_static_loader'
dataset_config = {
    'dataset': 'CSRF19_V1',
    'neuronal_data_files': neuronal_data_files,
    'image_cache_path': image_cache_path,
    'crop': 0,
    'subsample': 1,
    'seed': 1000,
    'time_bins_sum': 6,
    'batch_size': 128,
}
dataloaders = builder.get_data(dataset_fn, dataset_config)
# Peek at one raw stimulus: take the 12th training session's loader, pull the
# full dataset tensor, and view the first image's single channel as a numpy
# array for plotting.
train_session_ids = list(dataloaders["train"].keys())
twelfth_loader = dataloaders["train"][train_session_ids[11]]
some_image = twelfth_loader.dataset[:].inputs[0, 0, ::].cpu().numpy()
plt.imshow(some_image, cmap='gray')
# (notebook output) <matplotlib.image.AxesImage at 0x7f66d2aa46a0>
# Restrict training to the first recording session.
first_session_id = next(iter(dataloaders['train']))
train_loader_first_session = dataloaders['train'][first_session_id]
train_loader = train_loader_first_session

# Test batching differs from training: each test batch is purely repeats of
# one stimulus, so keep a single image tensor per batch and wrap the
# de-duplicated list in a fresh DataLoader.
test_loader_first_session = dataloaders['test'][first_session_id]
testset_images = []
for inputs, targets in test_loader_first_session:
    testset_images.append(inputs[0])
test_loader = DataLoader(testset_images)

# Downscale every image to the 28x28 resolution the VAE expects.
resizer = transforms.Resize(size=(28, 28))
import matplotlib.pyplot as plt

# Visual sanity check: show a handful of originals next to their resized
# copies (5x5-inch figure for the resized image).
for shown, (images, _) in enumerate(train_loader):
    # Original stimulus, channels-last for matplotlib.
    plt.imshow(images[0].permute(1, 2, 0))
    plt.show()
    # The same stimulus after downscaling to 28x28.
    data_resized = resizer(images)
    plt.figure(figsize=(5, 5))
    plt.imshow(data_resized[0].permute(1, 2, 0))
    plt.show()
    if shown > 5:
        break
import matplotlib.pyplot as plt

# Same before/after resize check, this time with a small 2x2-inch figure.
for seen, (batch, _) in enumerate(train_loader):
    plt.imshow(batch[0].permute(1, 2, 0))  # original, channels-last
    plt.show()
    data_resized = resizer(batch)
    plt.figure(figsize=(2, 2))
    plt.imshow(data_resized[0].permute(1, 2, 0))  # downscaled to 28x28
    plt.show()
    if seen > 5:
        break
import matplotlib.pyplot as plt

# Repeat of the check without forcing a figure size (matplotlib default).
for count, (imgs, _) in enumerate(train_loader):
    plt.imshow(imgs[0].permute(1, 2, 0))  # original, channels-last
    plt.show()
    data_resized = resizer(imgs)
    plt.imshow(data_resized[0].permute(1, 2, 0))  # resized to 28x28
    plt.show()
    if count > 5:
        break
import matplotlib.pyplot as plt

# Final check: additionally print the resized tensor shape
# (expected: torch.Size([1, 28, 28])).
for idx, (pics, _) in enumerate(train_loader):
    plt.imshow(pics[0].permute(1, 2, 0))  # original, channels-last
    plt.show()
    data_resized = resizer(pics)
    plt.imshow(data_resized[0].permute(1, 2, 0))  # resized to 28x28
    plt.show()
    print(data_resized[0].shape)
    if idx > 5:
        break
# (notebook output)
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
# torch.Size([1, 28, 28])
class VAE(nn.Module):
    """Fully-connected variational autoencoder (Kingma & Welling, 2014).

    Encoder: input -> hidden -> (mu, logvar) of a diagonal Gaussian latent.
    Decoder: latent -> hidden -> input with a sigmoid output in [0, 1].

    The layer sizes were hard-coded to 784/400/20 (MNIST-sized); they are
    now constructor parameters with those values as defaults, so existing
    callers (``VAE()``) are unaffected while other resolutions work too.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        # Encoder: shared hidden layer, then separate heads for mu / logvar.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        # Decoder mirrors the encoder.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened input to the latent Gaussian's (mu, logvar)."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps, eps ~ N(0, I) (differentiable)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample back to pixel space in [0, 1] via sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        # Flatten whatever spatial layout comes in to (batch, input_dim).
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
# Instantiate the VAE on the selected device and set up Adam with the
# 1e-3 learning rate used by the original PyTorch VAE example.
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO: Bernoulli reconstruction BCE + KL(q(z|x) || N(0, I)).

    NOTE(review): binary_cross_entropy requires targets in [0, 1]. The
    neural stimuli fed in here do not appear to be normalized to that
    range, which is why the logged losses go negative — normalize the
    inputs (or switch to an MSE/Gaussian likelihood) before trusting
    the numbers.
    """
    # Flatten x to match recon_x, inferring the feature count from recon_x
    # instead of the previous hard-coded 784, so the loss works for any
    # input dimensionality.
    BCE = F.binary_cross_entropy(recon_x, x.view(recon_x.size(0), -1), reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
def train(epoch):
    """Run one epoch of VAE training over train_loader, logging progress."""
    model.train()
    running_loss = 0.0
    for step, (data, _) in enumerate(train_loader):
        # Downscale to 28x28 and move the batch onto the training device.
        data = resizer(data)
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()

        # Periodic progress report with the per-sample loss of this batch.
        if step % args.log_interval == 0:
            seen = step * len(data)
            total = len(train_loader.dataset)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total,
                100. * step / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, running_loss / len(train_loader.dataset)))
def test(epoch):
    """Evaluate the VAE on test_loader and save a reconstruction panel.

    Accumulates the summed ELBO loss over the test set (reported as a
    per-sample average) and, for the first batch only, writes originals
    stacked with reconstructions to
    results_neural_imgs/reconstruction_<epoch>.png.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            data = resizer(data)
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # BUG FIX: the original reshaped with a hard-coded batch of
                # 1 (view(1, 1, 28, 28)), which only worked because this
                # test DataLoader happens to use batch_size=1. Infer the
                # batch dimension so any batch size works.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           'results_neural_imgs/reconstruction_' + str(epoch) + '.png', nrow=n)
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
# Ensure the output directory exists up front: save_image does not create
# parent directories and would raise FileNotFoundError on the first save.
os.makedirs('results_neural_imgs', exist_ok=True)

# Main loop: train then evaluate each epoch, and decode 64 random latents
# to monitor unconditional sample quality.
for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results_neural_imgs/sample_' + str(epoch) + '.png')
# (notebook output) Train Epoch: 1 [0/10154 (0%)] Loss: 547.298096 Train Epoch: 1 [1280/10154 (12%)] Loss: -936.209106 Train Epoch: 1 [2560/10154 (25%)] Loss: -6559.416504 Train Epoch: 1 [3840/10154 (38%)] Loss: -9107.536133 Train Epoch: 1 [5120/10154 (50%)] Loss: -8163.457520 Train Epoch: 1 [6400/10154 (62%)] Loss: -9853.704102 Train Epoch: 1 [7680/10154 (75%)] Loss: -10343.613281 Train Epoch: 1 [8960/10154 (88%)] Loss: -11297.556641 ====> Epoch: 1 Average loss: -6985.5558 ====> Test set loss: -9589.2421 Train Epoch: 2 [0/10154 (0%)] Loss: -9470.880859 Train Epoch: 2 [1280/10154 (12%)] Loss: -8694.735352 Train Epoch: 2 [2560/10154 (25%)] Loss: -8768.881836 Train Epoch: 2 [3840/10154 (38%)] Loss: -10231.545898 Train Epoch: 2 [5120/10154 (50%)] Loss: -12052.275391 Train Epoch: 2 [6400/10154 (62%)] Loss: -12671.904297 Train Epoch: 2 [7680/10154 (75%)] Loss: -12165.340820 Train Epoch: 2 [8960/10154 (88%)] Loss: -10529.303711 ====> Epoch: 2 Average loss: -11339.1982 ====> Test set loss: -11997.5604 Train Epoch: 3 [0/10154 (0%)] Loss: -11772.670898 Train Epoch: 3 [1280/10154 (12%)] Loss: -12751.399414 Train Epoch: 3 [2560/10154 (25%)] Loss: -12693.022461 Train Epoch: 3 [3840/10154 (38%)] Loss: -14976.128906 Train Epoch: 3 [5120/10154 (50%)] Loss: -15523.180664 Train Epoch: 3 [6400/10154 (62%)] Loss: -11808.393555 Train Epoch: 3 [7680/10154 (75%)] Loss: -13146.436523 Train Epoch: 3 [8960/10154 (88%)] Loss: -14979.049805 ====> Epoch: 3 Average loss: -13495.9512 ====> Test set loss: -13222.9455 Train Epoch: 4 [0/10154 (0%)] Loss: -14326.793945 Train Epoch: 4 [1280/10154 (12%)] Loss: -15609.947266 Train Epoch: 4 [2560/10154 (25%)] Loss: -15193.702148 Train Epoch: 4 [3840/10154 (38%)] Loss: -13913.916992 Train Epoch: 4 [5120/10154 (50%)] Loss: -10424.295898 Train Epoch: 4 [6400/10154 (62%)] Loss: -17424.591797 Train Epoch: 4 [7680/10154 (75%)] Loss: -12666.552734 Train Epoch: 4 [8960/10154 (88%)] Loss: -15488.755859 ====> Epoch: 4 Average loss: -14369.8070 ====> Test set loss: -14390.1000 Train
# (notebook output) Epoch: 5 [0/10154 (0%)] Loss: -15826.942383 Train Epoch: 5 [1280/10154 (12%)] Loss: -14621.226562 Train Epoch: 5 [2560/10154 (25%)] Loss: -14064.419922 Train Epoch: 5 [3840/10154 (38%)] Loss: -13278.973633 Train Epoch: 5 [5120/10154 (50%)] Loss: -15402.059570 Train Epoch: 5 [6400/10154 (62%)] Loss: -13012.573242 Train Epoch: 5 [7680/10154 (75%)] Loss: -13511.670898 Train Epoch: 5 [8960/10154 (88%)] Loss: -15593.186523 ====> Epoch: 5 Average loss: -14889.7374 ====> Test set loss: -14629.8192 Train Epoch: 6 [0/10154 (0%)] Loss: -14264.577148 Train Epoch: 6 [1280/10154 (12%)] Loss: -18182.052734 Train Epoch: 6 [2560/10154 (25%)] Loss: -14620.293945 Train Epoch: 6 [3840/10154 (38%)] Loss: -15153.944336 Train Epoch: 6 [5120/10154 (50%)] Loss: -13352.387695 Train Epoch: 6 [6400/10154 (62%)] Loss: -16509.148438 Train Epoch: 6 [7680/10154 (75%)] Loss: -16626.496094 Train Epoch: 6 [8960/10154 (88%)] Loss: -17149.833984 ====> Epoch: 6 Average loss: -15573.0033 ====> Test set loss: -15521.5018 Train Epoch: 7 [0/10154 (0%)] Loss: -17588.066406 Train Epoch: 7 [1280/10154 (12%)] Loss: -13671.388672 Train Epoch: 7 [2560/10154 (25%)] Loss: -15189.336914 Train Epoch: 7 [3840/10154 (38%)] Loss: -16411.789062 Train Epoch: 7 [5120/10154 (50%)] Loss: -15855.462891 Train Epoch: 7 [6400/10154 (62%)] Loss: -15437.014648 Train Epoch: 7 [7680/10154 (75%)] Loss: -15987.178711 Train Epoch: 7 [8960/10154 (88%)] Loss: -14512.131836 ====> Epoch: 7 Average loss: -16046.4828 ====> Test set loss: -16437.1989 Train Epoch: 8 [0/10154 (0%)] Loss: -14688.786133 Train Epoch: 8 [1280/10154 (12%)] Loss: -16835.740234 Train Epoch: 8 [2560/10154 (25%)] Loss: -14397.198242 Train Epoch: 8 [3840/10154 (38%)] Loss: -18143.455078 Train Epoch: 8 [5120/10154 (50%)] Loss: -18137.214844 Train Epoch: 8 [6400/10154 (62%)] Loss: -16139.227539 Train Epoch: 8 [7680/10154 (75%)] Loss: -16227.043945 Train Epoch: 8 [8960/10154 (88%)] Loss: -18998.798828 ====> Epoch: 8 Average loss: -16446.7962 ====> Test set loss:
# (notebook output) -16184.1335 Train Epoch: 9 [0/10154 (0%)] Loss: -15101.893555 Train Epoch: 9 [1280/10154 (12%)] Loss: -18096.724609 Train Epoch: 9 [2560/10154 (25%)] Loss: -18203.718750 Train Epoch: 9 [3840/10154 (38%)] Loss: -15050.525391 Train Epoch: 9 [5120/10154 (50%)] Loss: -14010.612305 Train Epoch: 9 [6400/10154 (62%)] Loss: -16090.330078 Train Epoch: 9 [7680/10154 (75%)] Loss: -18070.681641 Train Epoch: 9 [8960/10154 (88%)] Loss: -15878.502930 ====> Epoch: 9 Average loss: -16798.1500 ====> Test set loss: -16609.6481 Train Epoch: 10 [0/10154 (0%)] Loss: -17282.349609 Train Epoch: 10 [1280/10154 (12%)] Loss: -18306.154297 Train Epoch: 10 [2560/10154 (25%)] Loss: -20235.498047 Train Epoch: 10 [3840/10154 (38%)] Loss: -18665.654297 Train Epoch: 10 [5120/10154 (50%)] Loss: -18450.544922 Train Epoch: 10 [6400/10154 (62%)] Loss: -17041.603516 Train Epoch: 10 [7680/10154 (75%)] Loss: -18883.472656 Train Epoch: 10 [8960/10154 (88%)] Loss: -16736.302734 ====> Epoch: 10 Average loss: -17231.8471 ====> Test set loss: -17388.9580